import random

import pandas as pd
import numpy as np

from collections import defaultdict

# Pin both the stdlib and NumPy global RNGs to the same fixed seed so that
# every run of this script draws the same sequence of random numbers.
random.seed(1990)
np.random.seed(1990)

def get_partitions_sizes(skewness_factor, partitions, num_examples, descending=True):
	"""Split `num_examples` into `partitions` geometrically skewed bin sizes.

	Bin k (1-based) is first sized proportionally to `skewness_factor ** k`.
	Because every size is floored, a few examples are usually left over;
	they are handed out starting from the largest bin, each bin receiving
	`ceil(remaining / skewness_factor)` of what is still unassigned. Anything
	that is still left after every bin has been visited goes to the largest
	bin, so the returned sizes sum to `num_examples`.

	Args:
		skewness_factor: Ratio between two consecutive bin sizes; expected
			to be a positive number.
		partitions: Number of bins to produce.
		num_examples: Total number of examples to distribute (an integer).
		descending: If True (default) return largest bin first, otherwise
			smallest first.

	Returns:
		A list of `partitions` integer bin sizes, sorted according to
		`descending`.
	"""
	# CAUTION pidx starts from 0, thus +1 (need it for factorized bin sizes).
	factorized_bin_sizes = [np.power(skewness_factor, pidx + 1) for pidx in range(partitions)]
	# Number of examples represented by one "unit" of the geometric series.
	bin_size_factor = np.floor(num_examples / sum(factorized_bin_sizes))

	# Actual bin sizes; sizes are counts, so store them as plain ints.
	partitions_sizes = [int(np.floor(factor * bin_size_factor)) for factor in factorized_bin_sizes]

	remaining_examples = num_examples - sum(partitions_sizes)
	if remaining_examples > 0 and partitions_sizes:
		# Flooring lost some examples: hand them out from the largest bin
		# to the smallest.
		partitions_sizes.sort(reverse=True)
		for idx in range(len(partitions_sizes)):
			if remaining_examples <= 0:
				break
			# Never hand out more than what is left (matters when
			# skewness_factor < 1, where the ceil term exceeds the remainder).
			increment = min(remaining_examples, int(np.ceil(remaining_examples / skewness_factor)))
			partitions_sizes[idx] += increment
			remaining_examples -= increment
		# The geometric hand-out may stop short when there are few bins or
		# a large factor; give whatever is left to the largest bin so no
		# example is silently dropped.
		partitions_sizes[0] += remaining_examples

	return sorted(partitions_sizes, reverse=descending)


if __name__ == "__main__":
	# Script entry point: split the centralized training set into one
	# partition per site, where each site draws subjects whose integer age
	# follows a normal distribution N(MEANS[i], STD), and write one CSV per
	# site.
	import os

	TRAINING_DATA_PATH = "../centralized/train.csv"
	OUTPUT_DIR = "generated_distributions"
	# Upper bound on consecutive rejected draws while looking for a single
	# subject. Without a bound the rejection loop can spin (practically)
	# forever once the ages close to a site's mean have been used up.
	MAX_SAMPLING_ATTEMPTS = 100000

	data = pd.read_csv(TRAINING_DATA_PATH)
	# Integer age used to match a sampled age against the subjects.
	data["age_bin"] = data["age_at_scan"].astype(int)
	num_examples = len(data.index)

	# (1) Uniform & IID.
	SIZES = [900, 900, 900, 900, 900, 900, 900, 900]
	MEANS = [65, 65, 65, 65, 65, 65, 65, 65]
	STD = 5

	# (2) Uniform & Non-IID.
	# SIZES = [900, 900, 900, 900, 900, 900, 900, 900]
	# MEANS = [50, 60, 70, 80, 50, 60, 70, 80]
	# STD = 7

	# (3) Skewed and Non-IID.
	# SIZES = get_partitions_sizes(skewness_factor=1.35, partitions=8, num_examples=num_examples, descending=True)
	# MEANS = [90, 70, 50, 30, 90, 70, 50, 30]
	# STD = 10

	# Map site idx (as a string) to the list of subject rows assigned to it.
	node_to_data = defaultdict(list)
	# Sites that hit MAX_SAMPLING_ATTEMPTS; they are skipped for the rest of pass 1.
	exhausted_sites = set()
	# Age bins that still have at least one unassigned subject; lets a
	# rejected draw be detected without scanning the whole dataframe.
	remaining_bins = set(data["age_bin"].tolist())
	keep_running = True

	# Pass 1
	# Visit the sites round-robin, giving each under-capacity site one
	# subject per sweep. Stop when the dataset is empty or when no site can
	# take another subject (at capacity or exhausted).
	while keep_running and len(data) > 0:

		keep_running = False  # Flipped back on as soon as one site is still filling.
		for i in range(len(MEANS)):

			if len(data) == 0:
				break

			site_key = str(i)
			if len(node_to_data[site_key]) >= SIZES[i]:
				continue
			if i in exhausted_sites:
				continue

			keep_running = True
			# Rejection sampling: draw ages until one matches a remaining subject.
			for _ in range(MAX_SAMPLING_ATTEMPTS):
				sample = int(np.random.normal(loc=MEANS[i], scale=STD))
				if sample not in remaining_bins:
					continue

				x = data[data["age_bin"] == sample].iloc[0]
				# Add subj to site set.
				node_to_data[site_key].append(x)
				# Remove subj from dataset.
				data = data[data["eid"] != x["eid"]]
				remaining_bins = set(data["age_bin"].tolist())
				break
			else:
				# No draw matched: stop sampling for this site; whatever is
				# left in the dataset is handed out in pass 2.
				exhausted_sites.add(i)
				print("Site {} stopped pass 1 early: no matching age after {} draws".format(i, MAX_SAMPLING_ATTEMPTS))
		print("Remaining Samples: ", len(data))

	for i in range(len(SIZES)):
		print(len(node_to_data[str(i)]), SIZES[i])

	# Pass 2
	# Hand out any subjects left after pass 1 round-robin across the sites,
	# one subject per site per sweep, regardless of SIZES.
	while len(data) > 0:
		for i in range(len(SIZES)):
			if len(data) > 0:
				x = data.iloc[0]
				# Add subj to site set
				node_to_data[str(i)].append(x)
				data = data[data["eid"] != x["eid"]]

	# Make sure the output directory exists before writing the partitions.
	os.makedirs(OUTPUT_DIR, exist_ok=True)

	federation_data_distribution = defaultdict(dict)
	for sidx in node_to_data:
		# Output files and stats are keyed by 1-based partition id.
		fidx = str(int(sidx) + 1)
		pd.DataFrame(node_to_data[sidx]).to_csv(os.path.join(OUTPUT_DIR, "train_{}.csv".format(fidx)), index=False)
		partition_data = [x['age_at_scan'] for x in node_to_data[sidx]]
		print("Partition ID: ", fidx)
		print("Partition Size: ", len(partition_data))
		print("Partition Data: ", partition_data)
		print("Mean: {}, STD: {}".format(np.mean(partition_data), np.std(partition_data)))
		federation_data_distribution[fidx]['train_stats'] = dict()
		federation_data_distribution[fidx]['train_stats']['dataset_size'] = len(partition_data)
		federation_data_distribution[fidx]['train_stats']['dataset_values'] = partition_data
